# Direct outputs not eligible for public release

# GLOBAL SETTINGS --------------------------------------------------------------

# Session-wide options for this script:
# - warn = 1 prints warnings as they occur
# - scipen / digits favour fixed notation with many significant digits
# - max.print is raised to the largest integer so printed output is not cut off
options(
    warn = 1,
    show.error.locations = TRUE,
    max.print = .Machine$integer.max,
    digits = 16,
    scipen = 999
)

# Fix the random number generator and the main seed; 'seed' is kept in the
# environment after this point
seed <- 818675309L
RNGkind(kind = "L'Ecuyer-CMRG")
set.seed(seed = seed)

# PACKAGES ---------------------------------------------------------------------
library(data.table)
library(zoo)
library(multiwayvcov)
library(lmtest)
library(checkmate)
library(futile.logger)
library(lfe)

library(remotes)
# eventStudy is installed from GitHub. Only fetch it when it is not already
# installed, so that the script does not need network access on every run and
# does not trigger an install attempt each time it is executed.
if (!requireNamespace("eventStudy", quietly = TRUE)) {
    remotes::install_github("setzler/eventStudy/eventStudy")
}
library(eventStudy) # for DiD-IV

# PACKAGE SETTINGS -------------------------------------------------------------

# data.table
# - printing a data.table also shows each variable's type on top, plus any keys
# - restrict data.table to a single thread
options(datatable.print.keys = TRUE, datatable.print.class = TRUE)
setDTthreads(threads = 1L)

# BEGIN FILE -------------------------------------------------------------------

# Call base::source() explicitly in case something has masked 'source()'.
base::source(file = "~/code/0-utility-functions/wald_es.R", local = TRUE)
outcome <- "db_w2_wages"

# Read the saved event-study estimates for this outcome; the coefficient table
# is the first element of the stored object
event_study_output <-
    readRDS(
        file = sprintf(
            "~/estimation-output/event_study_estimates_%s.rds", outcome
        )
    )
did_laterwinners_coefs <- setDT(event_study_output[[1]])
rm(event_study_output)

# Keep the cohort-weighted reduced-form ATT rows for event times -7 through 5,
# and only the columns needed downstream
did_laterwinners_coefs <-
    did_laterwinners_coefs[
        rn == "att" &
            model == "reduced_form" &
            ref_onset_time == "Cohort-Weighted" &
            between(ref_event_time, -7L, 5L, incbounds = TRUE),
        list(ref_event_time, estimate, cluster_se)
    ]

# Introduce an omitted_event_time row (event time -2, with a zero estimate and
# zero standard error)
omitted_row <-
    data.table(
        ref_event_time = -2L,
        estimate = 0,
        cluster_se = 0
    )
did_laterwinners_coefs <-
    rbindlist(
        list(did_laterwinners_coefs, omitted_row),
        use.names = TRUE
    )
rm(omitted_row)
setorder(did_laterwinners_coefs, ref_event_time)

# Work backwards to a levels graph: since the estimates are DiD, pin down the
# following three quantities in the data and back out the fourth from the DiD
# estimate
# Y_{w-s}(0)
# Y_{w-s}(1)
# Y_{w+ℓ}(0)

# Read in all of the annual files and stack them into one long panel
# - use "_slim" files as they aren't as wide / more memory-efficient
file_list <-
    list.files(
        path = "~/population-annual-data",
        pattern = "lottery\\_data\\_slim",
        full.names = TRUE
    )

# One data.table per annual file, in the order returned by list.files()
lottery_panel <-
    lapply(
        file_list,
        function(annual_file) {
            setDT(readRDS(file = annual_file))
        }
    )
rm(file_list)

lottery_panel <- rbindlist(l = lottery_panel, use.names = TRUE)

# Order rows by unit and then by tax year
setorder(lottery_panel, tin, tax_yr)

# For all outcomes, a missing value is replaced with zero
set(
    lottery_panel,
    i = which(is.na(lottery_panel[[outcome]])),
    j = outcome,
    value = 0
)

# Flag ages 21-64 (inclusive). Each treated cohort will be restricted to age
# 21-64 in year 0 and only control units age 21-64 in the same calendar year
# will be used.
lottery_panel[, age_case := between(age, 21, 64, incbounds = TRUE)]

# Every column other than the identifiers, the onset year and the outcome is
# handed to ES_clean_data() as a discrete covariate so it is carried along
vars_to_pass_along <-
    setdiff(
        names(lottery_panel),
        c("tin", "tax_yr", "win_yr", outcome)
    )

# Build the event-study data set; event time -2 is the omitted period and the
# age flag is applied to treated and control units at event time 0
es_dt <-
    ES_clean_data(
        long_data = copy(lottery_panel),
        outcomevar = outcome,
        unit_var = "tin",
        cal_time_var = "tax_yr",
        onset_time_var = "win_yr",
        cluster_vars = "tin",
        omitted_event_time = -2L,
        discrete_covars = copy(vars_to_pass_along),
        treated_subset_var = "age_case",
        treated_subset_event_time = 0L,
        control_subset_var = "age_case",
        control_subset_event_time = 0L
    )

# Raw (no controls) mean of the outcome by reference cohort, event time and
# treatment arm.
# The outcome column is selected through 'outcome' instead of a hard-coded
# column name, so these means always follow the outcome chosen at the top of
# the file. Using .SDcols + lapply(.SD, mean) keeps the grouped mean in the
# same optimised form as mean(<column>) would be.
es_treatcontrol_means_rf_nocontrols <-
    es_dt[
        ,
        lapply(.SD, mean),
        by = .(ref_onset_time, ref_event_time, treated),
        .SDcols = outcome
    ][
        order(ref_onset_time, ref_event_time, treated)
    ]
setnames(es_treatcontrol_means_rf_nocontrols, outcome, "estimate")

# Demean by age within each reference cohort (matches the DiD we do)
# - first compute the mean outcome for every (reference cohort, age) cell
es_age_means <-
    es_dt[
        ,
        lapply(.SD, mean),
        by = .(ref_onset_time, age),
        .SDcols = outcome
    ][
        order(ref_onset_time, age)
    ]
setnames(es_age_means, outcome, "age_mean")

# - attach the cell mean to every observation and subtract it
es_dt <-
    merge(
        es_dt,
        es_age_means,
        by = c("age", "ref_onset_time"),
        all = TRUE,
        sort = FALSE
    )
es_dt[, outcome_age_demean := get(outcome) - age_mean]

# Age-demeaned mean of the outcome by reference cohort, event time and
# treatment arm
es_treatcontrol_means_rf_controls <-
    es_dt[
        ,
        .(estimate = mean(outcome_age_demean)),
        by = .(ref_onset_time, ref_event_time, treated)
    ][
        order(ref_onset_time, ref_event_time, treated)
    ]

# Normalize to match omitted_event_time of the uncontrolled series: for every
# reference cohort and treatment arm, find the gap between the raw mean and the
# age-demeaned mean at event time -2, then add that gap back to the whole
# age-demeaned series.

# Omitted-event-time rows of both series, extracted once for the loop below
raw_at_omitted <-
    es_treatcontrol_means_rf_nocontrols[ref_event_time == -2L]
demeaned_at_omitted <-
    es_treatcontrol_means_rf_controls[ref_event_time == -2L]

normalization <- list()
for (cc in sort(unique(es_dt$ref_onset_time))) {

    print(cc)

    # Rows of this reference cohort only
    raw_cc <- raw_at_omitted[ref_onset_time == cc]
    demeaned_cc <- demeaned_at_omitted[ref_onset_time == cc]

    normalization[[length(normalization) + 1L]] <-
        data.table(
            ref_onset_time = cc,
            treat_factor =
                raw_cc[treated == 1]$estimate -
                demeaned_cc[treated == 1]$estimate,
            control_factor =
                raw_cc[treated == 0]$estimate -
                demeaned_cc[treated == 0]$estimate
        )

    # Housekeeping
    rm(raw_cc, demeaned_cc)
}
rm(cc, raw_at_omitted, demeaned_at_omitted)

normalization <- rbindlist(normalization, use.names = TRUE)

# Attach the per-cohort shift factors to the age-demeaned series
es_treatcontrol_means_rf_controls <-
    merge(
        x = es_treatcontrol_means_rf_controls,
        y = normalization,
        by = "ref_onset_time",
        all.x = TRUE,
        sort = FALSE
    )

# Apply the shift of the matching treatment arm (the two updates touch
# disjoint sets of rows)
es_treatcontrol_means_rf_controls[
    treated == 0L, estimate := estimate + control_factor
]
es_treatcontrol_means_rf_controls[
    treated == 1L, estimate := estimate + treat_factor
]

# Start the levels table from the normalized age-adjusted means, keeping event
# times -7 through 5 (subsetting returns a new table, so the source table is
# left untouched), and add the calendar year
did_laterwinners_levels <-
    es_treatcontrol_means_rf_controls[
        between(ref_event_time, -7L, 5L, incbounds = TRUE)
    ]
did_laterwinners_levels[, year := ref_onset_time + ref_event_time]

# Label the two arms for the figure; treated == 1 maps to the first label,
# treated == 0 to the second, anything else stays NA
arm_labels <- c("Treated (Current Winners)", "Control (Later Winners)")
did_laterwinners_levels[
    ,
    treat_label :=
        factor(
            arm_labels[match(treated, c(1L, 0L))],
            levels = arm_labels
        )
]
rm(arm_labels)

# Now need to introduce cohort weights (to match our DiD)
# - exclude -2L as that omitted time is re-used, so will overcount

# Number of treated rows per (reference cohort, event time) cell
cohort_weights <-
    es_dt[
        (ref_event_time != -2L) & treated == 1L,
        .N,
        by = .(ref_onset_time, ref_event_time)
    ][
        order(ref_onset_time, ref_event_time)
    ]

# Turn the counts into shares within each event time, then drop the counts
cohort_weights[, cohort_weight := N / sum(N), by = ref_event_time]
cohort_weights[, N := NULL]

# Above explicitly missing -2L, just replace with data from -1L
# - everyone has -2L and -1L, will just use the -1L weights
omitted_weights <-
    cohort_weights[ref_event_time == -1L][, ref_event_time := -2L]
cohort_weights <-
    rbindlist(
        list(cohort_weights, omitted_weights),
        use.names = TRUE
    )
rm(omitted_weights)

# Attach the weight of each (reference cohort, event time) cell to the levels
did_laterwinners_levels <-
    merge(
        x = did_laterwinners_levels,
        y = cohort_weights,
        by = c("ref_onset_time", "ref_event_time"),
        all.x = TRUE,
        sort = FALSE
    )

# Collapse across reference cohorts to cohort-weighted means
# - 'treat_label' is part of the grouping only so it survives for graph text
did_laterwinners_levels <-
    did_laterwinners_levels[
        ,
        list(estimate = weighted.mean(estimate, cohort_weight)),
        by = c("ref_event_time", "treated", "treat_label")
    ]

# Pin illustrative values for figure (t = treated, c = control)
# - pre_line_t_val: value in omitted year, -2, for the treated (current winners)
# - pre_line_c_val: value in omitted year, -2, for the control (later winners)
# - post_line_t_val: mean in the first five years post win for the treated
# - post_line_c_val: mean in the first five years post win for the control
# Only the two pre-period values are set here.
pre_line_t_val <-
    did_laterwinners_levels[treated == 1L & ref_event_time == -2L, estimate]
pre_line_c_val <-
    did_laterwinners_levels[treated == 0L & ref_event_time == -2L, estimate]

# Now can merge in the DiD, and then undo it in the data.table
# - the coefficient table is merged from a renamed copy ('estimate' -> 'did'),
#   so 'did_laterwinners_coefs' itself keeps its column names
did_laterwinners_levels <-
    merge(
        x = did_laterwinners_levels,
        y = setnames(copy(did_laterwinners_coefs), "estimate", "did"),
        by = "ref_event_time",
        all.x = TRUE,
        sort = FALSE
    )

# Treated-minus-control gap at the omitted event time
did_laterwinners_levels[, pre_diff := pre_line_t_val - pre_line_c_val]

# Control level at each event time, copied onto every row of that event time
did_laterwinners_levels[
    ,
    post_c_level := sum((treated == 0L) * estimate),
    by = ref_event_time
]

# Treated level implied by the DiD: DiD + pre-period gap + control level
did_laterwinners_levels[
    treated == 1L,
    estimate := did + pre_diff + post_c_level
]

# Average level over event times 1 through 5, separately for each arm
post_line_t_val <-
    mean(
        did_laterwinners_levels[
            treated == 1L &
                between(ref_event_time, 1L, 5L, incbounds = TRUE),
            estimate
        ]
    )
post_line_c_val <-
    mean(
        did_laterwinners_levels[
            treated == 0L &
                between(ref_event_time, 1L, 5L, incbounds = TRUE),
            estimate
        ]
    )

# Reference-line columns: pre lines are filled for event times <= 0, post
# lines for event times >= 0 (event time 0 carries both)
did_laterwinners_levels[
    ref_event_time <= 0L,
    `:=`(pre_line_t = pre_line_t_val, pre_line_c = pre_line_c_val)
]
did_laterwinners_levels[
    ref_event_time >= 0L,
    `:=`(post_line_t = post_line_t_val, post_line_c = post_line_c_val)
]

# Store the table for the figure
saveRDS(
    object = did_laterwinners_levels,
    file = "~/estimation-output/estimates-for-figure-3-1-c.rds"
)

# Housekeeping: set every object created above to NULL and then remove it
# (together with the two helper names used to do so)
stale_objects <-
    c(
        "outcome",
        "did_laterwinners_coefs",
        "lottery_panel",
        "vars_to_pass_along",
        "es_dt",
        "es_treatcontrol_means_rf_nocontrols",
        "es_age_means",
        "es_treatcontrol_means_rf_controls",
        "normalization",
        "cohort_weights",
        "pre_line_t_val",
        "pre_line_c_val",
        "post_line_t_val",
        "post_line_c_val",
        "did_laterwinners_levels"
    )
for (stale_name in stale_objects) {
    assign(stale_name, NULL)
}
rm(list = c(stale_objects, "stale_objects", "stale_name"))